{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook compares various Factorization Machines implementations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# I - Factorization Machines"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dataset used here is [MovieLens 100K](https://grouplens.org/datasets/movielens/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Python implementation: CPython\n",
      "Python version       : 3.8.5\n",
      "IPython version      : 7.22.0\n",
      "\n",
      "river  : 0.7.0\n",
      "numpy  : 1.20.2\n",
      "pandas : 1.2.4\n",
      "sklearn: 0.24.1\n",
      "xlearn : 0.4.4\n",
      "\n",
      "Compiler    : Clang 10.0.0 \n",
      "OS          : Darwin\n",
      "Release     : 20.3.0\n",
      "Machine     : x86_64\n",
      "Processor   : i386\n",
      "CPU cores   : 8\n",
      "Architecture: 64bit\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark --python --machine --packages river,numpy,pandas,sklearn,xlearn --datename"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LibFM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download and uncompress [`libfm`](http://www.libfm.org/) into the working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import tarfile\n",
    "import urllib\n",
    "\n",
    "archive = 'libfm.tar.gz'\n",
    "with urllib.request.urlopen('http://www.libfm.org/libfm-1.42.src.tar.gz') as r, open(archive, 'wb') as f:\n",
    "    shutil.copyfileobj(r, f)\n",
    "\n",
    "tar = tarfile.open(archive, 'r:gz')\n",
    "tar.extractall('.')\n",
    "tar.close()\n",
    "\n",
    "os.remove(archive)\n",
    "\n",
    "libfm_dir = 'libfm-1.42.src'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compile the tools."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cd src/libfm; make all\n",
      "g++ -O3 -Wall -c libfm.cpp -o libfm.o\n",
      "g++ -O3 -Wall libfm.o -o ../../bin/libFM\n",
      "g++ -O3 -Wall -c tools/transpose.cpp -o tools/transpose.o\n",
      "g++ -O3 tools/transpose.o -o ../../bin/transpose\n",
      "g++ -O3 -Wall -c tools/convert.cpp -o tools/convert.o\n",
      "g++ -O3 tools/convert.o -o ../../bin/convert\n"
     ]
    }
   ],
   "source": [
    "%%bash -s \"$libfm_dir\"\n",
    "cd $1\n",
    "make all"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's prepare our dataset to [`libfm`](http://www.libfm.org/) format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user</th>\n",
       "      <th>item</th>\n",
       "      <th>gender</th>\n",
       "      <th>occupation</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>259</td>\n",
       "      <td>255</td>\n",
       "      <td>M</td>\n",
       "      <td>student</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>259</td>\n",
       "      <td>286</td>\n",
       "      <td>M</td>\n",
       "      <td>student</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>259</td>\n",
       "      <td>298</td>\n",
       "      <td>M</td>\n",
       "      <td>student</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>259</td>\n",
       "      <td>185</td>\n",
       "      <td>M</td>\n",
       "      <td>student</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>259</td>\n",
       "      <td>173</td>\n",
       "      <td>M</td>\n",
       "      <td>student</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  user item gender occupation    y\n",
       "0  259  255      M    student  4.0\n",
       "1  259  286      M    student  4.0\n",
       "2  259  298      M    student  4.0\n",
       "3  259  185      M    student  4.0\n",
       "4  259  173      M    student  4.0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from river import datasets\n",
    "\n",
    "def merge_X_y(x, y):\n",
    "    x['y'] = y\n",
    "    return x\n",
    "    \n",
    "ml_100k = [merge_X_y(x, y) for x, y in datasets.MovieLens100K()]\n",
    "ml_100k = pd.DataFrame(ml_100k)\n",
    "ml_100k = ml_100k[['user', 'item', 'gender', 'occupation', 'y']]\n",
    "\n",
    "ml_100k.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Perform a 80/20 train test split and one hot encode categorical features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(ml_100k.drop(columns='y'), ml_100k[['y']], test_size=0.2, random_state=17)\n",
    "\n",
    "ohe = OneHotEncoder(handle_unknown='ignore')\n",
    "\n",
    "X_train = ohe.fit_transform(X_train)\n",
    "X_test = ohe.transform(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the data to text files ready to use with [`libfm`](http://www.libfm.org/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from sklearn.datasets import dump_svmlight_file\n",
    "\n",
    "train_file, test_file = 'libfm_train.txt', 'libfm_test.txt'\n",
    "\n",
    "with open(train_file, 'wb') as f:\n",
    "    dump_svmlight_file(X_train, y_train.values.squeeze(), f=f)\n",
    "    \n",
    "with open(test_file, 'wb') as f:\n",
    "    dump_svmlight_file(X_test, np.zeros(len(y_test)), f=f)\n",
    "\n",
    "pred_file = 'libfm_pred.txt'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use [`libfm`](http://www.libfm.org/) to train a model and predict the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------------------\n",
      "libFM\n",
      "  Version: 1.4.2\n",
      "  Author:  Steffen Rendle, srendle@libfm.org\n",
      "  WWW:     http://www.libfm.org/\n",
      "This program comes with ABSOLUTELY NO WARRANTY; for details see license.txt.\n",
      "This is free software, and you are welcome to redistribute it under certain\n",
      "conditions; for details see license.txt.\n",
      "----------------------------------------------------------------------------\n",
      "Loading train...\t\n",
      "has x = 1\n",
      "has xt = 0\n",
      "num_rows=80000\tnum_values=320000\tnum_features=2622\tmin_target=1\tmax_target=5\n",
      "Loading test... \t\n",
      "has x = 1\n",
      "has xt = 0\n",
      "num_rows=20000\tnum_values=79973\tnum_features=2622\tmin_target=0\tmax_target=0\n",
      "#relations: 0\n",
      "Loading meta data...\t\n",
      "learnrate=0.01\n",
      "learnrates=0.01,0.01,0.01\n",
      "#iterations=1\n",
      "SGD: DON'T FORGET TO SHUFFLE THE ROWS IN TRAINING DATA TO GET THE BEST RESULTS.\n",
      "#Iter=  0\tTrain=0.954638\tTest=3.69095\n",
      "Final\tTrain=0.954638\tTest=3.69095\n"
     ]
    }
   ],
   "source": [
    "%%bash -s \"$libfm_dir\" \"$train_file\" \"$test_file\" \"$pred_file\"\n",
    "cd $1\n",
    "./bin/libFM -task r -dim '1,1,10' -method sgd -iter 1 -learn_rate 0.01 -init_stdev 0.1 -regular '0,0,0' -train ../$2 -test ../$3 -out ../$4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load [`libfm`](http://www.libfm.org/) predictions into memory and compute MAE and RMSE scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LibFM MAE: 0.7619\n",
      "LibFM RMSE: 0.9701\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "libfm_pred = pd.read_csv(pred_file, names=['y'])\n",
    "\n",
    "print(f'LibFM MAE: {mean_absolute_error(y_test, libfm_pred):.4f}')\n",
    "print(f'LibFM RMSE: {mean_squared_error(y_test, libfm_pred) ** 0.5:.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clean up the working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "to_remove = [train_file, test_file, pred_file]\n",
    "\n",
    "for path in to_remove:\n",
    "    os.remove(path)\n",
    "\n",
    "shutil.rmtree(libfm_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## River"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's do the same thing with [River](https://online-ml.github.io/index.html) now!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "River MAE: 0.7598\n",
      "River RMSE: 0.9727\n"
     ]
    }
   ],
   "source": [
    "from river import facto\n",
    "from river import meta\n",
    "from river import optim\n",
    "from river import stream\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(ml_100k.drop(columns='y'),\n",
    "                                                    ml_100k[['y']],\n",
    "                                                    test_size=0.2,\n",
    "                                                    random_state=17)\n",
    "\n",
    "fm_params = {\n",
    "    'n_factors': 10,\n",
    "    'weight_optimizer': optim.SGD(0.01),\n",
    "    'latent_optimizer': optim.SGD(0.01),\n",
    "    'l1_weight': 0.,\n",
    "    'l2_weight': 0.,\n",
    "    'l1_latent': 0.,\n",
    "    'l2_latent': 0.,\n",
    "    'intercept': 0.,\n",
    "    'intercept_lr': 0.01,\n",
    "    'weight_initializer': optim.initializers.Zeros(),\n",
    "    'latent_initializer': optim.initializers.Normal(mu=0., sigma=0.1, seed=85),\n",
    "}\n",
    "\n",
    "model = meta.PredClipper(\n",
    "    regressor=facto.FMRegressor(**fm_params),\n",
    "    y_min=1,\n",
    "    y_max=5\n",
    ")\n",
    "\n",
    "for x, y in stream.iter_pandas(X_train, y_train):\n",
    "    model.learn_one(x, y['y'])\n",
    "\n",
    "river_pred = [model.predict_one(x) for x, _ in stream.iter_pandas(X_test)]\n",
    "\n",
    "print(f'River MAE: {mean_absolute_error(y_test, river_pred):.4f}')\n",
    "print(f'River RMSE: {mean_squared_error(y_test, river_pred) ** 0.5:.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| FM - MovieLens100K | MAE      | RMSE     |\n",
    "|:-------------------|:--------:|:--------:|\n",
    "| LibFM              |  0.7619  |  0.9701  |\n",
    "| River              |  0.7598  |  0.9727  |"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# II - Field-aware Factorization Machines"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dataset used here is a 1% subsampling from [Criteo's challenge](https://www.kaggle.com/c/criteo-display-ad-challenge) built by [`libffm`](https://github.com/ycjuan/libffm). Clic [here](https://drive.google.com/uc?export=download&confirm=v1vT&id=1HZX7zSQJy26hY4_PxSlOWz4x7O-tbQjt) to download it from their Google drive, then move it into the working directory. Let's uncompress it now."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import zipfile\n",
    "\n",
    "archive = 'libffm_toy.zip'\n",
    "with zipfile.ZipFile(archive, 'r') as zf:\n",
    "    zf.extractall('.')\n",
    "\n",
    "os.remove(archive)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LibFFM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download and uncompress [`libffm`](https://github.com/ycjuan/libffm) into the working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "archive = 'libffm.zip'\n",
    "with urllib.request.urlopen('https://github.com/ycjuan/libffm/archive/master.zip') as r, open(archive, 'wb') as f:\n",
    "    shutil.copyfileobj(r, f)\n",
    "\n",
    "with zipfile.ZipFile(archive, 'r') as zf:\n",
    "    zf.extractall('.')\n",
    "\n",
    "os.remove(archive)\n",
    "\n",
    "libffm_dir = 'libffm-master'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compile the tools."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "g++ -Wall -O3 -std=c++0x -march=native -DUSESSE -c -o timer.o timer.cpp\n",
      "g++ -Wall -O3 -std=c++0x -march=native -DUSESSE -c -o ffm.o ffm.cpp\n",
      "g++ -Wall -O3 -std=c++0x -march=native -DUSESSE -o ffm-train ffm-train.cpp ffm.o timer.o\n",
      "g++ -Wall -O3 -std=c++0x -march=native -DUSESSE -o ffm-predict ffm-predict.cpp ffm.o timer.o\n"
     ]
    }
   ],
   "source": [
    "%%bash -s \"$libffm_dir\"\n",
    "cd $1\n",
    "make"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use [`libffm`](https://github.com/ycjuan/libffm) to train a model and predict the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_file = 'libffm_toy/criteo.tr.r100.gbdt0.ffm'\n",
    "test_file = 'libffm_toy/criteo.va.r100.gbdt0.ffm'\n",
    "model_file = 'libffm_model'\n",
    "pred_file = 'libffm_pred'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "First check if the text file has already been converted to binary format (0.0 seconds)\n",
      "Binary file NOT found. Convert text file to binary file (2.0 seconds)\n",
      "iter   tr_logloss      tr_time\n",
      "   1      0.61859          9.1\n",
      "logloss = 0.52888\n"
     ]
    }
   ],
   "source": [
    "%%bash -s \"$train_file\" \"$test_file\" \"$model_file\" \"$pred_file\"\n",
    "cd libffm-master\n",
    "./ffm-train -l 0.0 -k 10 -t 1 -r 0.01 -s 4 ../$1 ../$3\n",
    "./ffm-predict ../$2 ../$3 ../$4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load [`libffm`](https://github.com/ycjuan/libffm) predictions into memory and compute Accuracy, Log loss and ROC AUC scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LibFFM Accuracy: 0.7485\n",
      "LibFFM Log loss: 0.5289\n",
      "LibFFM ROC AUC: 0.6910\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import log_loss\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "y_test = pd.read_csv(test_file, sep=' ', names=['y_true'] + [i for i in range(39)], usecols=['y_true'])\n",
    "libffm_pred = pd.read_csv(pred_file, names=['y_hat'])\n",
    "\n",
    "print(f'LibFFM Accuracy: {accuracy_score(y_test, libffm_pred > .5):.4f}')\n",
    "print(f'LibFFM Log loss: {log_loss(y_test, libffm_pred):.4f}')\n",
    "print(f'LibFFM ROC AUC: {roc_auc_score(y_test, libffm_pred):.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clean up the working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.remove(model_file)\n",
    "os.remove(pred_file)\n",
    "\n",
    "shutil.rmtree(libffm_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## xLearn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use [`xlearn`](https://xlearn-doc.readthedocs.io/en/latest/index.html) to train a model and predict the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import xlearn as xl\n",
    "\n",
    "xlearn_model = xl.create_ffm()\n",
    "xlearn_model.setSigmoid()\n",
    "xlearn_model.setTrain(train_file)\n",
    "xlearn_model.disableNorm() # Disable instance-wise normalization\n",
    "\n",
    "xlearn_params = {\n",
    "    'task': 'binary',\n",
    "    'k': 10,\n",
    "    'epoch': 1,\n",
    "    'opt': 'sgd',\n",
    "    'lr': 0.01,\n",
    "    'lambda': 0.0,\n",
    "    'nthread': 4\n",
    "}\n",
    "\n",
    "model_file = 'xlearn_model'\n",
    "pred_file = 'xlearn_pred'\n",
    "\n",
    "xlearn_model.fit(xlearn_params, model_file)\n",
    "\n",
    "xlearn_model.setTest(test_file)\n",
    "xlearn_model.predict('xlearn_model', pred_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load [`xlearn`](https://xlearn-doc.readthedocs.io/en/latest/index.html) predictions into memory and compute Accuracy, Log loss and ROC AUC scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xLearn Accuracy: 0.7642\n",
      "xLearn Log loss: 0.5247\n",
      "xLearn ROC AUC: 0.7401\n"
     ]
    }
   ],
   "source": [
    "xlearn_pred = pd.read_csv(pred_file, names=['y_hat'])\n",
    "\n",
    "print(f'xLearn Accuracy: {accuracy_score(y_test, xlearn_pred > .5):.4f}')\n",
    "print(f'xLearn Log loss: {log_loss(y_test, xlearn_pred):.4f}')\n",
    "print(f'xLearn ROC AUC: {roc_auc_score(y_test, xlearn_pred):.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clean up the working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.remove(model_file)\n",
    "os.remove(pred_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## River"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Format data in order to be compatible with [River](https://online-ml.github.io/index.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_criteo_data(filepath):\n",
    "    X = pd.read_csv(filepath, sep=' ', names=['y'] + [str(i) for i in range(39)])\n",
    "    y = X[['y']].copy()\n",
    "    X = X.drop(columns='y').applymap(lambda x: x.split(':')[1])\n",
    "    return X, y\n",
    "\n",
    "X_train, y_train = load_criteo_data(train_file)\n",
    "X_test, y_test = load_criteo_data(test_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use [River](https://online-ml.github.io/index.html) to train a model and predict the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "River Accuracy: 0.7551\n",
      "River Log loss: 0.5134\n",
      "River ROC AUC: 0.7422\n"
     ]
    }
   ],
   "source": [
    "ffm_params = {\n",
    "    'n_factors': 10,\n",
    "    'weight_optimizer': optim.SGD(0.01),\n",
    "    'latent_optimizer': optim.SGD(0.01),\n",
    "    'l1_weight': 0.,\n",
    "    'l2_weight': 0.,\n",
    "    'l1_latent': 0.,\n",
    "    'l2_latent': 0.,\n",
    "    'intercept': 0.,\n",
    "    'intercept_lr': 0.01,\n",
    "    'weight_initializer': optim.initializers.Zeros(),\n",
    "    'latent_initializer': optim.initializers.Normal(mu=0., sigma=0.1, seed=85),\n",
    "}\n",
    "\n",
    "model = facto.FFMClassifier(**ffm_params)\n",
    "\n",
    "for x, y in stream.iter_pandas(X_train, y_train):\n",
    "    model.learn_one(x, y['y'])\n",
    "\n",
    "river_pred = [model.predict_proba_one(x)[True] for x, _ in stream.iter_pandas(X_test)]\n",
    "river_pred = pd.Series(river_pred)\n",
    "\n",
    "print(f'River Accuracy: {accuracy_score(y_test, river_pred > .5):.4f}')\n",
    "print(f'River Log loss: {log_loss(y_test, river_pred):.4f}')\n",
    "print(f'River ROC AUC: {roc_auc_score(y_test, river_pred):.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clean up the working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "shutil.rmtree('libffm_toy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| FFM - Criteo subsampled | Accuracy | Log loss | ROC AUC |\n",
    "|:------------------------|:--------:|:--------:|:-------:|\n",
    "| LibFFM                  |  0.7485  |  0.5289  |  0.6910 |\n",
    "| xLearn                  |  0.7642  |  0.5247  |  0.7401 |\n",
    "| River                   |  0.7551  |  0.5134  |  0.7422 |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "river-benchmarks",
   "language": "python",
   "name": "river-benchmarks"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
